# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train model on ascend device."""
import argparse
import os
import numpy as np
from mindspore import context
from mindspore.profiler import Profiler

from train import train

# Seed NumPy's global random number generator with a fixed value.
np.random.seed(74)


def is_flag_enabled(value):
    """Interpret a string command-line flag as a boolean.

    The profiler switch is passed as a plain string (ModelArts-style launchers
    typically forward values such as ``"False"``), so a literal comparison
    against ``"false"`` would wrongly treat ``"False"`` or ``"0"`` as "on".

    Args:
        value (str): Raw flag value as received from the command line.

    Returns:
        bool: ``False`` when ``value`` (whitespace-stripped, case-insensitive)
        is one of ``"false"``, ``"0"``, ``"no"``, ``"off"`` or empty;
        ``True`` for every other value. Unrecognised values therefore still
        count as enabled, as they did before.
    """
    return value.strip().lower() not in ("false", "0", "no", "off", "")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore Bert Training Example')
    # ModelArts uses this argument to pass in the OBS data folder URL.
    parser.add_argument('--data_url', type=str, default="./data", help='Dataset path in OBS.')
    # ModelArts uses this argument to pass in the OBS output folder URL.
    parser.add_argument('--train_url', type=str, default="./output", help='Training output data path in OBS.')
    parser.add_argument('--summary_folder', type=str, default="summary", help='Summary folder name.')
    parser.add_argument('--use_profiler', type=str, default="false",
                        help='Whether to enable profiler ("false", "0", "no" or "off" disables it).')
    parser.add_argument('--test_sentence', type=str, default=None, help="Sentence to be used for testing.")
    args = parser.parse_args()

    LOCAL_DATA_PATH = args.data_url
    LOCAL_OUTPUT_PATH = args.train_url
    # Summary directory (also used as the profiler output directory).
    SUMMARY_DIR = os.path.join(LOCAL_OUTPUT_PATH, args.summary_folder)
    # Training dataset path.
    TRAIN_DATASET = os.path.join(LOCAL_DATA_PATH, "train.mindrecord")
    # Testing dataset path.
    TEST_DATASET = os.path.join(LOCAL_DATA_PATH, "test.mindrecord")
    # Pre-trained model checkpoint path.
    PRETRAINED_MODEL_CKPT_PATH = os.path.join(LOCAL_DATA_PATH, "bert_zh.ckpt")

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    # Decide once. The start and analyse steps then cannot disagree, and
    # `profiler` is always bound (None when profiling is disabled).
    profiler = Profiler(output_path=SUMMARY_DIR) if is_flag_enabled(args.use_profiler) else None
    train(TRAIN_DATASET, TEST_DATASET, PRETRAINED_MODEL_CKPT_PATH, SUMMARY_DIR, args.test_sentence)
    if profiler is not None:
        profiler.analyse()
